{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于BERT Fine-Tune的语义角色标注模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part I. 模型与数据集构建"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入相关库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import torch.utils.data as torchData\n",
    "\n",
    "import time\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载pretrained bert model\n",
    "##### def seq2ids(seq) ->  input tensor of model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_pretrained_bert as bert\n",
    "\n",
    "# Directory holding the pretrained checkpoint; both the tokenizer and the\n",
    "# encoder below are loaded from it.\n",
    "bert_model_dir = \"bert_model/bert-chinese/\"\n",
    "\n",
    "tokenizer = bert.BertTokenizer.from_pretrained(bert_model_dir)\n",
    "\n",
    "# NOTE: this rebinds the name `bert` from the imported package to the loaded\n",
    "# BertModel instance. SemanticRoleLabelModel reads this global, so the name\n",
    "# must keep pointing at the model after this cell has run.\n",
    "bert = bert.BertModel.from_pretrained(bert_model_dir )\n",
    "\n",
    "def seq2ids(seq):\n",
    "    \"\"\"Convert a token sequence into an id tensor for the BERT encoder.\n",
    "\n",
    "    The ids returned by the tokenizer are wrapped in an outer list, so the\n",
    "    resulting tensor carries a leading dimension of size 1.\n",
    "    \"\"\"\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(seq)\n",
    "    batch_of_one = [token_ids]\n",
    "    return torch.tensor(batch_of_one)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构建我们的模型\n",
    "\n",
    "#### 模型在 BERT 上方加入一个两层的全连接网络，使 BERT 的输出能够转化为对所有语义标签的输出概率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SemanticRoleLabelModel(torch.nn.Module):\n",
    "    \"\"\"BERT encoder with a feed-forward head that scores every token against all labels.\n",
    "\n",
    "    Args:\n",
    "        n_hidden: width of the hidden layer in the classification head.\n",
    "        labelnum: number of labels, i.e. the size of each output distribution.\n",
    "        bert_model: encoder to fine-tune. Defaults to the notebook-level ``bert``\n",
    "            instance, so existing two-argument calls keep working.\n",
    "        dropout: dropout probability used inside the head.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_hidden, labelnum, bert_model=None, dropout=0.2):\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        self.bert = bert if bert_model is None else bert_model\n",
    "\n",
    "        # Read the encoder width from its config instead of hard-coding 768,\n",
    "        # so checkpoints with a different hidden size can be plugged in.\n",
    "        encoder_dim = self.bert.config.hidden_size\n",
    "\n",
    "        self.output = torch.nn.Sequential(\n",
    "            torch.nn.Linear(encoder_dim, n_hidden),\n",
    "            torch.nn.Tanh(),\n",
    "            torch.nn.Dropout(dropout),\n",
    "            torch.nn.Linear(n_hidden, labelnum),\n",
    "            # NOTE(review): this final Tanh bounds the scores to [-1, 1], which\n",
    "            # limits how peaked the softmax below can become. Confirm it is intended.\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "        # Normalises over the label dimension of a (tokens, labels) score matrix.\n",
    "        self.prob = torch.nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, b_seqs):\n",
    "        \"\"\"Return one probability matrix per input sequence.\n",
    "\n",
    "        Args:\n",
    "            b_seqs: iterable of id tensors; each is assumed to hold a single\n",
    "                sequence with a leading dimension of 1, as built by ``seq2ids``.\n",
    "\n",
    "        Returns:\n",
    "            A list with one tensor per sequence containing the softmax output\n",
    "            of the head for every token.\n",
    "        \"\"\"\n",
    "        probs = []\n",
    "\n",
    "        for seqs in b_seqs:\n",
    "            # With output_all_encoded_layers=False the first returned element is\n",
    "            # the hidden states of the last encoder layer only.\n",
    "            encoded = self.bert(seqs, output_all_encoded_layers=False)[0]\n",
    "\n",
    "            # Index 0 selects the first (and presumably only) sequence in the batch.\n",
    "            token_scores = self.output(encoded[0])\n",
    "\n",
    "            probs.append(self.prob(token_scores))\n",
    "\n",
    "        return probs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构建数据集\n",
    "\n",
    "#### 我们使用845篇文章的前600篇作为训练集 后245篇作为测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import load_seq, split,Extract_Information_By_Re as E\n",
    "\n",
    "class MyDataSet(torchData.Dataset):\n",
    "    \"\"\"Sequences and labels gathered from the article corpus.\n",
    "\n",
    "    Articles 0-599 make up the training split, articles 600-844 the test split.\n",
    "    Each article is read with ``load_seq`` and cut into examples with ``split``.\n",
    "\n",
    "    Args:\n",
    "        train_or_test: ``\"train\"`` or ``\"test\"``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``train_or_test`` names an unknown split. Previously this\n",
    "            silently produced an empty dataset.\n",
    "    \"\"\"\n",
    "\n",
    "    # Article index range covered by each split.\n",
    "    SPLITS = {\"train\": range(0, 600), \"test\": range(600, 845)}\n",
    "\n",
    "    def __init__(self,train_or_test = \"train\"):\n",
    "\n",
    "        if train_or_test not in self.SPLITS:\n",
    "            raise ValueError(\n",
    "                \"train_or_test must be 'train' or 'test', got {!r}\".format(train_or_test)\n",
    "            )\n",
    "\n",
    "        self.seqs = []\n",
    "\n",
    "        self.labels = []\n",
    "\n",
    "        for article_id in self.SPLITS[train_or_test]:\n",
    "\n",
    "            seqs, labels = split(*load_seq(article_id, withlabel=True))\n",
    "\n",
    "            self.seqs += seqs\n",
    "\n",
    "            self.labels += labels\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of examples collected for this split.\"\"\"\n",
    "        return len(self.seqs)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(x, y)`` for the example at ``index``.\n",
    "\n",
    "        ``x`` is the id tensor from ``seq2ids`` moved to the GPU; ``y`` is the\n",
    "        result of ``E.labels2index`` applied to the example's labels.\n",
    "        \"\"\"\n",
    "        x = seq2ids(self.seqs[index]).cuda()\n",
    "\n",
    "        y = E.labels2index(self.labels[index])\n",
    "\n",
    "        return x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 实例化我们的模型，优化器，损失函数 与 数据集\n",
    "\n",
    "Tips:\n",
    "\n",
    "1. 我们的模型通过GPU计算\n",
    "2. 损失函数：交叉熵\n",
    "3. 优化器：Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning rates. Pretrained BERT weights are normally fine-tuned with a very\n",
    "# small step (2e-5 to 5e-5); a rate such as 1e-2 overwrites what the encoder\n",
    "# learned during pretraining. The randomly initialised head can move faster.\n",
    "BERT_LR = 2e-5\n",
    "HEAD_LR = 1e-3\n",
    "\n",
    "Model = SemanticRoleLabelModel(300, len(E.labels))\n",
    "\n",
    "Model = torch.nn.DataParallel(Model).cuda()\n",
    "\n",
    "# DataParallel keeps the wrapped network under `.module`; the two groups below\n",
    "# cover every trainable parameter (the Softmax layer has none).\n",
    "optimizer = torch.optim.Adam([\n",
    "    {'params': Model.module.bert.parameters(), 'lr': BERT_LR},\n",
    "    {'params': Model.module.output.parameters(), 'lr': HEAD_LR},\n",
    "])\n",
    "\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "trainingSet = MyDataSet(train_or_test = \"train\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part II. 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: device-side assert triggered",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-23-0a4797d5d4ee>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      9\u001b[0m     \u001b[0mused_seq\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m     \u001b[1;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdataGenerator\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     12\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m         \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32md:\\python36\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    343\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    344\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 345\u001b[1;33m         \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    346\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    347\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32md:\\python36\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    383\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    384\u001b[0m         \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 385\u001b[1;33m         \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    386\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    387\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32md:\\python36\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     42\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32md:\\python36\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     42\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-10-2d6e11eb4f69>\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m     35\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m__getitem__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     36\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 37\u001b[1;33m         \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mseq2ids\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mseqs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     38\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     39\u001b[0m         \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mE\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlabels2index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlabels\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered"
     ]
    }
   ],
   "source": [
    "# Train for 10 epochs over shuffled mini-batches of 10 examples.\n",
    "# The identity collate_fn hands each batch over as a plain list of (ids, label)\n",
    "# pairs, because the sequences have different lengths and cannot be stacked.\n",
    "dataGenerator = torchData.DataLoader(dataset = trainingSet,batch_size = 10,shuffle = True,collate_fn = lambda data : data)\n",
    "\n",
    "for epoch in range(10):\n",
    "\n",
    "    epoch_loss = []\n",
    "\n",
    "    start_time = time.time()\n",
    "\n",
    "    time_usage = 0.0\n",
    "\n",
    "    used_seq = 0\n",
    "\n",
    "    for data in dataGenerator:\n",
    "\n",
    "        b_seq = [item[0] for item in data]\n",
    "\n",
    "        b_label = [item[1] for item in data]\n",
    "\n",
    "        b_prob = Model(b_seq)\n",
    "\n",
    "        sentence_losses = []\n",
    "\n",
    "        for prob, label in zip(b_prob, b_label):\n",
    "\n",
    "            target = torch.tensor(label)\n",
    "\n",
    "            # Validate on the CPU first: a bad target otherwise only surfaces later\n",
    "            # as an opaque \"CUDA error: device-side assert triggered\".\n",
    "            if target.numel() != prob.shape[0]:\n",
    "                raise ValueError(\n",
    "                    \"label length {} does not match sequence length {}\".format(target.numel(), prob.shape[0])\n",
    "                )\n",
    "            if target.numel() > 0 and int(target.max()) >= prob.shape[1]:\n",
    "                raise ValueError(\n",
    "                    \"label index {} is outside the {} known labels\".format(int(target.max()), prob.shape[1])\n",
    "                )\n",
    "\n",
    "            # The model already returns softmax probabilities, while CrossEntropyLoss\n",
    "            # applies log-softmax itself and expects unnormalised scores. Feeding it\n",
    "            # log(prob) yields the true cross-entropy, since softmax(log p) == p.\n",
    "            log_prob = torch.log(torch.clamp(prob, min=1e-12))\n",
    "\n",
    "            sentence_losses.append(loss(log_prob, target.to(prob.device)))\n",
    "\n",
    "        # Average of the per-sequence losses in this batch.\n",
    "        b_loss = torch.stack(sentence_losses).mean()\n",
    "\n",
    "        epoch_loss.append(b_loss.item())\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        b_loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report progress once the batch has actually been processed.\n",
    "        time_usage = time.time() - start_time\n",
    "\n",
    "        used_seq += len(b_seq)\n",
    "\n",
    "        process = str(round( used_seq / len(trainingSet) * 100 , 6)) + \"%\"\n",
    "\n",
    "        print(\"\\r epoch : {},process : {} ,timeUsage:{}\".format(epoch,process,time_usage),end = \"\",flush=True)\n",
    "\n",
    "    # Guard against an empty data loader so the summary line cannot divide by zero.\n",
    "    mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float(\"nan\")\n",
    "\n",
    "    print(\"\\repoch {} : Loss :{} ,timeUsage:{}\".format(epoch, mean_loss, time_usage))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
